{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Example Usage of L-GCN"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prerequisites"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import IPython\n",
    "import pickle\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from models.LGCN import LGCN\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = pickle.load(open(\"data/1hop_500\", \"rb\"))\n",
    "\n",
    "# normalization\n",
    "val,pos = dataset.x.max(dim=0)\n",
    "dataset.x /= val.abs()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define Edge Module\n",
    "\n",
    "Create the learning mechanism that is to operate on the edge populations / multi-edges."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Reshape(torch.nn.Module):\n",
    "    def __init__(self, *args):\n",
    "        super(Reshape, self).__init__()\n",
    "        self.shape = args\n",
    "\n",
    "    def forward(self, x):\n",
    "        return x.view(self.shape)\n",
    "\n",
    "\n",
    "def mm_CONV(conv_channels=20, out_channels=4):\n",
    "    return torch.nn.Sequential(\n",
    "        torch.nn.Conv1d(2, conv_channels, kernel_size=3, stride=1, padding=1),\n",
    "        torch.nn.AdaptiveMaxPool1d(1),\n",
    "        torch.nn.ReLU(),\n",
    "        Reshape(-1, conv_channels),\n",
    "        torch.nn.Linear(conv_channels, 2*out_channels),\n",
    "        torch.nn.Dropout(p=0.2),\n",
    "        torch.nn.ReLU(),\n",
    "        torch.nn.Linear(2*out_channels, out_channels),\n",
    "        torch.nn.ReLU()\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create Example Nets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Parameter `L` determines size of latent representations of the edge populations. Parameter `H1` determines representation size in the intermediate node embedding layer. `H2` determines the size of the final output layer and should agree with the downstream task configured in the data set.\n",
    "\n",
    "In the GCN layers, the following controls are available:\n",
    "* `make_bidirectional` offers bidirectional propagation over directed graphs\n",
    "* `neighbor_nl` offers additional per-neighbor nonlinearity *inside* the graph convolution (L-GCN+)\n",
    "* `DVE` provides the option of embedding local neighborhood aggregations of the latent representations (mean-pool) directly on the nodes, before proceeding with the GCN (L-GCN+ & DVE)\n",
    "\n",
    "\n",
    "In these examples, edge populations are pre-padded with zeros and sorted by original sequence length, accompanied by batch cut-offs for faster processing. The `edge_attr_cutoffs` parameter may be omitted to proceed without batching.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### L4-GCN (bidirectional propagation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LGCN_Net(torch.nn.Module):\n",
    "    def __init__(self, L=4, H1=20, H2=2):\n",
    "        super().__init__()\n",
    "        self.conv1 = LGCN(dataset.num_features, H1, mm_CONV(out_channels=L), L=L,\n",
    "                          make_bidirectional=True)\n",
    "        self.conv2 = LGCN(H1, H2, mm_CONV(out_channels=L), L=L,\n",
    "                          make_bidirectional=True)\n",
    "\n",
    "    def forward(self, data):\n",
    "        x, edge_index, edge_attr, edge_attr_cutoffs = data.x, data.edge_index, data.edge_attr, data.edge_attr_cutoffs\n",
    "        x = self.conv1(x, edge_index, edge_attr, edge_attr_cutoffs=edge_attr_cutoffs)\n",
    "        x = F.relu(x)\n",
    "        x = F.dropout(x, training=self.training)\n",
    "        x = self.conv2(x, edge_index, edge_attr, edge_attr_cutoffs=edge_attr_cutoffs)\n",
    "\n",
    "        return F.log_softmax(x, dim=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### L4-GCN+ (bidirectional propagation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LGCN_Net2(torch.nn.Module):\n",
    "    def __init__(self, L=4, H1=20, H2=2):\n",
    "        super().__init__()\n",
    "        self.conv1 = LGCN(dataset.num_features, H1, mm_CONV(out_channels=L),L=L,\n",
    "                          make_bidirectional=True,\n",
    "                          neighbor_nl=True)\n",
    "        self.conv2 = LGCN(H1, H2, mm_CONV(out_channels=L), L=L,\n",
    "                          make_bidirectional=True,\n",
    "                          neighbor_nl=True)\n",
    "\n",
    "    def forward(self, data):\n",
    "        x, edge_index, edge_attr, edge_attr_cutoffs = data.x, data.edge_index, data.edge_attr, data.edge_attr_cutoffs\n",
    "        x = self.conv1(x, edge_index, edge_attr, edge_attr_cutoffs=edge_attr_cutoffs)\n",
    "        x = F.relu(x)\n",
    "        x = F.dropout(x, training=self.training)\n",
    "        x = self.conv2(x, edge_index, edge_attr, edge_attr_cutoffs=edge_attr_cutoffs)\n",
    "\n",
    "        return F.log_softmax(x, dim=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### L4-GCN+ & DVE (bidirectional propagation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LGCN_Net3(torch.nn.Module):\n",
    "    def __init__(self, L=4, H1=20, H2=2):\n",
    "        super().__init__()\n",
    "        self.conv1 = LGCN(dataset.num_features, H1, mm_CONV(out_channels=L), L=L,\n",
    "                          make_bidirectional=True,\n",
    "                          neighbor_nl=True,\n",
    "                          DVE=True)\n",
    "        self.conv2 = LGCN(H1, H2, mm_CONV(out_channels=L), L=L,\n",
    "                          make_bidirectional=True,\n",
    "                          neighbor_nl=True)\n",
    "\n",
    "    def forward(self, data):\n",
    "        x, edge_index, edge_attr, edge_attr_cutoffs = data.x, data.edge_index, data.edge_attr, data.edge_attr_cutoffs\n",
    "        x = self.conv1(x, edge_index, edge_attr, edge_attr_cutoffs=edge_attr_cutoffs)\n",
    "        x = F.relu(x)\n",
    "        x = F.dropout(x, training=self.training)\n",
    "        x = self.conv2(x, edge_index, edge_attr, edge_attr_cutoffs=edge_attr_cutoffs)\n",
    "\n",
    "        return F.log_softmax(x, dim=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Example Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Done."
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing Accuracy: 0.7958\n",
      "Number of parameters: 8930\n"
     ]
    }
   ],
   "source": [
    "epochs = 100\n",
    "lr = 5e-4\n",
    "weight_decay = 5e-4\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "model = LGCN_Net3().to(device)\n",
    "data = dataset.to(device)\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)\n",
    "model.train()\n",
    "\n",
    "train_class_ratio = dataset.y[dataset.train_mask].sum().item()/dataset.y[dataset.train_mask].shape[0]\n",
    "train_class_weights = torch.Tensor([train_class_ratio,1-train_class_ratio]).to(device)\n",
    "\n",
    "out = display(IPython.display.Pretty('Starting'), display_id=True)\n",
    "for epoch in range(epochs):\n",
    "    optimizer.zero_grad()\n",
    "    loss = F.nll_loss(model(data)[data.train_mask], data.y[data.train_mask], weight=train_class_weights)\n",
    "    loss.backward()\n",
    "    optimizer.step()    \n",
    "    out.update(IPython.display.Pretty(f\"Epoch {epoch+1}/{epochs}\"))\n",
    "\n",
    "out.update(IPython.display.Pretty(\"Done.\"))\n",
    "model.eval()\n",
    "\n",
    "test_acc = model(data).max(dim=1)[1][data.test_mask].eq(data.y[data.test_mask]).sum().item() / data.test_mask.sum().item()\n",
    "print('Testing Accuracy: {:.4f}'.format(test_acc))\n",
    "print('Number of parameters: {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
